package org.zjvis.datascience.service.dag;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.dto.TaskInstanceDTO;
import org.zjvis.datascience.common.enums.TaskInstanceStatus;

import java.util.*;

/**
 * Task management object. For each session id it tracks:
 * <ul>
 *   <li>the submitted task futures,</li>
 *   <li>a timestamp,</li>
 *   <li>membership in the sampling-session set,</li>
 *   <li>the DAG node instances of the current execution.</li>
 * </ul>
 *
 * <p>All state is held in concurrent collections. Compound read-then-act operations
 * ({@link #remove}, {@link #updateDagTaskMap}, and the DAG queries) are synchronized on this
 * instance.
 *
 * @date 2021-12-01
 */
@Service
public class TaskManager {

    private final static Logger logger = LoggerFactory.getLogger("TaskManager");

    /** Task futures per session id. */
    private final Map<Long, List<TaskFuture>> taskHolder = Maps.newConcurrentMap();

    /** Value (a timestamp, per {@link #putTimeHolder}) per session id. */
    private final Map<Long, Long> timeHolder = Maps.newConcurrentMap();

    /** Session ids registered through {@link #addSamplingSession}. */
    private final Set<Long> samplingSessions = Sets.newConcurrentHashSet();

    /** DAG node information per session: sessionId -> (instance id -> instance). */
    private final Map<Long, Map<Long, TaskInstanceDTO>> dagTaskMap = Maps.newConcurrentMap();

    /** Returns the live, mutable map of task futures keyed by session id. */
    public Map<Long, List<TaskFuture>> getTaskHolder() {
        return taskHolder;
    }

    /** Returns the live, mutable map of timestamps keyed by session id. */
    public Map<Long, Long> getTimeHolder() {
        return timeHolder;
    }

    /**
     * Returns the value recorded for the given session via {@link #putTimeHolder}.
     *
     * @param id session id
     * @return the recorded value (a timestamp), or {@code -1L} if none is recorded
     */
    public Long getSId(Long id) {
        return timeHolder.getOrDefault(id, -1L);
    }

    /** Stores (or replaces) the task futures of a session. */
    public void putTaskHolder(Long sessionId, List<TaskFuture> taskFutures) {
        this.taskHolder.put(sessionId, taskFutures);
    }

    /** Registers a session id in the sampling-session set. */
    public void addSamplingSession(Long session) {
        // The set is already concurrent. The lock is kept so that callers which synchronize
        // on getSamplingSessions() keep the same mutual exclusion.
        synchronized (samplingSessions) {
            samplingSessions.add(session);
        }
    }

    /** Returns the live sampling-session set. */
    public Set<Long> getSamplingSessions() {
        return samplingSessions;
    }

    /** Removes a session id from the sampling-session set; a no-op if it is absent. */
    public void removeSessionForSet(Long session) {
        // remove() is idempotent, so the previous unsynchronized contains() pre-check
        // (a racy check-then-act) is unnecessary.
        synchronized (samplingSessions) {
            samplingSessions.remove(session);
        }
    }

    /** Stores (or replaces) the timestamp of a session. */
    public void putTimeHolder(Long sessionId, Long timeStamp) {
        this.timeHolder.put(sessionId, timeStamp);
    }

    /** Drops the task futures of a session. */
    public void removeTaskHolder(Long sessionId) {
        this.taskHolder.remove(sessionId);
    }

    /** Drops the timestamp of a session. */
    public void removeTimeHolder(Long sessionId) {
        this.timeHolder.remove(sessionId);
    }

    /** Drops the DAG node map of a session. */
    public void removeDagTaskMap(Long sessionId) {
        this.dagTaskMap.remove(sessionId);
    }

    /** Removes every piece of state held for a session, under this instance's lock. */
    public void remove(Long sessionId) {
        synchronized (this) {
            this.removeTaskHolder(sessionId);
            this.removeTimeHolder(sessionId);
            this.removeDagTaskMap(sessionId);
            this.removeSessionForSet(sessionId);
        }
    }

    /** Stores (or replaces) the DAG node map (instance id -> instance) of a session. */
    public void putDagTaskMap(Long sessionId, Map<Long, TaskInstanceDTO> map) {
        this.dagTaskMap.put(sessionId, map);
    }

    /**
     * Replaces an existing node instance in the session's DAG map.
     *
     * <p>Only instances whose id is already present are replaced. An unknown id, a missing
     * session map (e.g. the session was already removed), or a null instance is logged as
     * an error and otherwise ignored.
     *
     * @param sessionId   session id
     * @param instanceDTO the updated instance; keyed by {@code instanceDTO.getId()}
     */
    public void updateDagTaskMap(Long sessionId, TaskInstanceDTO instanceDTO) {
        synchronized (this) {
            Map<Long, TaskInstanceDTO> map = this.dagTaskMap.get(sessionId);
            if (map != null && instanceDTO != null && map.containsKey(instanceDTO.getId())) {
                map.put(instanceDTO.getId(), instanceDTO);
            } else {
                logger.error("updateDagTaskMap fail!!!! sessionId={}, instanceId={}", sessionId,
                        instanceDTO == null ? null : instanceDTO.getId());
            }
        }
    }

    /**
     * Checks whether the given node, or any of its ancestors, has failed in the current
     * execution. If so, the caller should mark the current node as failed too.
     *
     * <p>Ancestors are found breadth-first through {@code getParentIdList()}, resolved
     * against the session's DAG map. Each parent id is expanded at most once, so shared
     * ancestors are not re-scanned and a cyclic parent chain still terminates.
     *
     * @param sessionId   session whose DAG map is used to resolve parent ids
     * @param instanceDTO the node to start from (its own status is checked as well)
     * @return {@code true} if the node or any reachable ancestor is failed according to
     *         {@code TaskInstanceStatus.isFailed}; {@code false} otherwise, including when
     *         {@code instanceDTO} is null
     */
    public synchronized boolean isAncestorFail(Long sessionId, TaskInstanceDTO instanceDTO) {
        if (instanceDTO == null) {
            return false;
        }
        Map<Long, TaskInstanceDTO> map = this.dagTaskMap.get(sessionId);
        Deque<TaskInstanceDTO> queue = new ArrayDeque<>();
        Set<Long> visitedParentIds = new HashSet<>();
        queue.offer(instanceDTO);
        while (!queue.isEmpty()) {
            // ArrayDeque rejects null elements, so poll() here never yields null.
            TaskInstanceDTO dto = queue.poll();
            if (TaskInstanceStatus.isFailed(dto.getStatus())) {
                return true;
            }
            List<Long> parentIdList = dto.getParentIdList();
            if (map == null || parentIdList == null) {
                // No DAG information to resolve parents against: nothing more to explore.
                continue;
            }
            for (Long parentId : parentIdList) {
                if (!visitedParentIds.add(parentId)) {
                    continue;
                }
                TaskInstanceDTO parent = map.get(parentId);
                if (parent != null) {
                    queue.offer(parent);
                }
            }
        }
        return false;
    }

    /**
     * Determines whether the given session still has any node instance running.
     *
     * @param sessionId session id
     * @return {@code true} if some instance's status satisfies
     *         {@code TaskInstanceStatus.isRunning}; {@code false} otherwise, including when
     *         the session has no DAG map
     */
    public synchronized boolean isRunning(Long sessionId) {
        Map<Long, TaskInstanceDTO> map = this.dagTaskMap.get(sessionId);
        if (map == null) {
            return false;
        }
        for (TaskInstanceDTO dto : map.values()) {
            if (dto != null && TaskInstanceStatus.isRunning(dto.getStatus())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Determines whether, in the given session, any instance of the given task is still
     * running.
     *
     * @param sessionId session id
     * @param taskId    task id compared against each instance's {@code getTaskId()}
     * @return {@code true} if a matching instance's status satisfies
     *         {@code TaskInstanceStatus.isRunning}; {@code false} otherwise, including when
     *         the session has no DAG map
     */
    public synchronized boolean isStillRunning(Long sessionId, Long taskId) {
        Map<Long, TaskInstanceDTO> map = this.dagTaskMap.get(sessionId);
        if (map == null) {
            return false;
        }
        for (TaskInstanceDTO dto : map.values()) {
            // Objects.equals tolerates instances whose task id is null.
            if (dto != null && Objects.equals(dto.getTaskId(), taskId)
                    && TaskInstanceStatus.isRunning(dto.getStatus())) {
                return true;
            }
        }
        return false;
    }

}
